import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import gridspec
from matplotlib.transforms import Bbox, TransformedBbox, blended_transform_factory
from matplotlib.ticker import ScalarFormatter

from mpl_toolkits.axes_grid1.inset_locator import BboxPatch, BboxConnector, BboxConnectorPatch
import seaborn as sns

# Global plot styling.
# sns.set() overwrites many rcParams (including xtick/ytick label sizes) with
# seaborn's theme defaults, so the tick label sizes must be configured AFTER
# it; any value set before sns.set() would simply be discarded.
sns.set()
matplotlib.rc('xtick', labelsize=30)
matplotlib.rc('ytick', labelsize=30)
# Make every figure apply tight_layout automatically when it is drawn.
plt.rcParams["figure.autolayout"] = True
def get_mean_std(ress):
    """Summarise ``ress`` along its leading axis.

    Args:
        ress: array-like whose axis 0 is reduced away (presumably one entry
            per independent run — confirm against the result files).

    Returns:
        ``(mean, std)`` computed with ``np.mean`` / ``np.std`` over axis 0;
        the standard deviation uses NumPy's default ``ddof=0``.
    """
    reduce_axis = 0
    centre = np.mean(ress, axis=reduce_axis)
    spread = np.std(ress, axis=reduce_axis)
    return centre, spread



def connect_bbox(bbox1, bbox2,
                 loc1a, loc2a, loc1b, loc2b,
                 prop_lines, prop_patches=None):    ## for creating zoomed in plots
    """Build the artists that visually link two bounding boxes.

    Two connector lines join corner ``loc1a`` of ``bbox1`` to corner ``loc2a``
    of ``bbox2`` and corner ``loc1b`` to corner ``loc2b``. Corner codes follow
    mpl_toolkits' convention: 1 upper right, 2 upper left, 3 lower left,
    4 lower right. A patch is drawn over each box, and a filled band spans the
    region between the two connectors.

    Args:
        bbox1, bbox2: the boxes to connect.
        loc1a, loc2a, loc1b, loc2b: corner codes for the two connectors.
        prop_lines: keyword properties for the connector lines.
        prop_patches: keyword properties for the box patches and the band.
            When omitted, a copy of ``prop_lines`` is used with its alpha
            (default 1) scaled by 0.2.

    Returns:
        ``(connector_a, connector_b, patch_on_bbox1, patch_on_bbox2, band)``.
        None of them is added to an axes here; the caller must do that.
    """
    if prop_patches is None:
        faded_alpha = prop_lines.get("alpha", 1) * 0.2
        prop_patches = prop_lines.copy()
        prop_patches["alpha"] = faded_alpha

    # Clipping is turned off so the lines can cross axes boundaries.
    connectors = []
    for corner_src, corner_dst in ((loc1a, loc2a), (loc1b, loc2b)):
        line = BboxConnector(bbox1, bbox2, loc1=corner_src, loc2=corner_dst, **prop_lines)
        line.set_clip_on(False)
        connectors.append(line)
    line_a, line_b = connectors

    box_patch_1, box_patch_2 = (BboxPatch(box, **prop_patches) for box in (bbox1, bbox2))

    band = BboxConnectorPatch(bbox1, bbox2,
                              loc1a=loc1a, loc2a=loc2a, loc1b=loc1b, loc2b=loc2b,
                              **prop_patches)
    band.set_clip_on(False)

    return line_a, line_b, box_patch_1, box_patch_2, band

def zoom_effect02(ax1, ax2, **kwargs):
    """Connect a zoomed axes (``ax1``) to the region it shows on ``ax2``.

    The highlighted region on ``ax2`` takes its x-extent from ``ax1``'s view
    limits, interpreted in ``ax2`` data coordinates. Vertically, ``ax1``'s
    y-range is mapped onto ``ax2``'s full axes height.

    ``kwargs`` style the connector lines. The box patches and the band reuse
    them with ``ec="none"`` and ``alpha=0.2``. The patch over ``ax1`` is added
    to ``ax1``; everything else is added to ``ax2``.

    Returns:
        ``(connector_a, connector_b, patch_on_ax1, patch_on_ax2, band)``.
    """
    # y: ax1 data -> (ax1 limits) unit box -> ax2 axes coordinates.
    y_via_ax1_limits = ax1.transScale + (ax1.transLimits + ax2.transAxes)
    # x comes from ax2 data coordinates; y comes from the chain above.
    mixed = blended_transform_factory(ax2.transData, y_via_ax1_limits)

    zoomed_box = ax1.bbox
    region_box = TransformedBbox(ax1.viewLim, mixed)

    patch_style = dict(kwargs)
    patch_style["ec"] = "none"
    patch_style["alpha"] = 0.2

    line_a, line_b, patch_on_ax1, patch_on_ax2, band = connect_bbox(
        zoomed_box, region_box,
        loc1a=2, loc2a=3, loc1b=1, loc2b=4,
        prop_lines=kwargs, prop_patches=patch_style,
    )

    ax1.add_patch(patch_on_ax1)
    for artist in (patch_on_ax2, line_a, line_b, band):
        ax2.add_patch(artist)

    return line_a, line_b, patch_on_ax1, patch_on_ax2, band




if __name__ == '__main__':
    # Only the script entry point needs os (output-directory creation).
    import os

    dataset = ['covertype', 'MagicTelescope', 'shuttle', 'mushroom', 'fashion', 'Plants']

    # (result-file pattern, line colour, legend label) per method, in legend
    # order. Each file is loaded with np.load and reduced over axis 0 by
    # get_mean_std, which yields one mean curve and one std curve. The leading
    # axis is presumably the independent runs; confirm against the result files.
    methods = [
        ('./results/eenet_results_{}.npy', 'goldenrod', 'EE-Net'),
        ('./results/NeuralUCB_regret_{}.npy', 'purple', 'NeuralUCB'),
        ('./results/NeuralTS_regret_{}.npy', 'red', 'NeuralTS'),
        ('./results/Neural_epsilon_regret_{}.npy', 'saddlebrown', r"Neural-$\epsilon$"),
        ('./results/neuralad_results_{}.npy', 'mediumblue', "NeuSquareCB"),
        ('./results/neuraladv2_results_{}.npy', 'green', 'NeuFastCB'),
    ]
    n_rounds = 5000     # x-axis extent shown in every figure
    band_alpha = 0.05   # opacity of the mean +/- std bands

    sns.set_style("whitegrid")
    # savefig raises if the target directory does not exist.
    os.makedirs('./figures', exist_ok=True)

    for d in dataset:
        fig, ax1 = plt.subplots(1, 1, dpi=500, figsize=(10, 8))

        for pattern, color, label in methods:
            mean, std = get_mean_std(np.load(pattern.format(d)))
            # Take x from the curve length rather than a fixed 5000. A result
            # file with a different number of rounds then no longer causes an
            # x/y length mismatch in plot/fill_between.
            x = np.arange(len(mean))
            ax1.plot(x, mean, color=color, label=label)
            ax1.fill_between(x, mean - std, mean + std, color=color, alpha=band_alpha)

        fig.supxlabel('Rounds', fontsize=30)
        fig.supylabel('Regret', fontsize=30)

        ax1.legend(prop={"size": 30})
        fig.suptitle(d, fontsize=30, y=0.95)

        # Math-text scientific notation on the y ticks, with no additive offset.
        y_formatter = ScalarFormatter(useMathText=True, useOffset=False)
        y_formatter.set_powerlimits((-1, 1))
        ax1.yaxis.set_major_formatter(y_formatter)

        ax1.set_xlim(0, n_rounds)
        ax1.margins(x=0)
        # NOTE(review): tight_layout() below (and rcParams figure.autolayout)
        # recompute the subplot parameters, so this adjustment is likely
        # overridden. Confirm whether it is still needed.
        fig.subplots_adjust(bottom=0.01)  # Decrease bottom margin
        fig.tight_layout()
        for ext in ('pdf', 'jpg'):
            fig.savefig('./figures/regret_{}.{}'.format(d, ext), dpi=500, bbox_inches='tight')

        # Release the figure. Otherwise every dataset's high-dpi figure stays
        # alive until the script exits.
        plt.close(fig)
